Conversation

am17an (Collaborator) commented on Oct 28, 2025

Found this bug while looking at a related issue. When using ggml_can_fuse_subgraph, the output nodes that are passed in are wrong. This causes `test-backend-ops` to still fuse the nodes (because they are not used elsewhere in the test graph), while fusion does not actually happen in the real gpt-oss model.
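
For context, here is a minimal sketch of the invariant the fusion check relies on (hypothetical types and function names, not the real `ggml_can_fuse_subgraph` API): a candidate chain of nodes may only be fused if every chain node that is read from outside the chain has been declared as a subgraph output, so passing the wrong output list changes which graphs pass the check.

```cpp
// Illustrative sketch only -- hypothetical structures, not the actual ggml code.
#include <unordered_set>
#include <vector>

struct Node {
    std::vector<const Node *> srcs; // inputs of this node
};

// A chain can be fused only if no intermediate result (a chain node that is not a
// declared output) is read by any node outside the chain.
bool can_fuse_chain(const std::vector<const Node *> & graph,
                    const std::vector<const Node *> & chain,
                    const std::vector<const Node *> & outputs) {
    const std::unordered_set<const Node *> in_chain(chain.begin(), chain.end());
    const std::unordered_set<const Node *> declared_out(outputs.begin(), outputs.end());

    for (const Node * node : graph) {
        if (in_chain.count(node)) {
            continue; // uses inside the chain are always allowed
        }
        for (const Node * src : node->srcs) {
            if (in_chain.count(src) && !declared_out.count(src)) {
                return false; // an undeclared intermediate leaks outside the fused region
            }
        }
    }
    return true;
}
```

In a small synthetic graph like the ones in `test-backend-ops`, nothing outside the chain reads the intermediates, so even a wrong output list lets the check pass; in the full model graph the same wrong list rejects the fusion, which matches the symptom described above.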

With this change

| model | size | params | backend | ngl | test | t/s |
| --- | --- | --- | --- | --- | --- | --- |
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | CUDA | 99 | tg32 | 193.33 ± 0.99 |
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | CUDA | 99 | tg64 | 189.30 ± 0.55 |
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | CUDA | 99 | tg128 | 186.83 ± 0.34 |

Current master

| model | size | params | backend | ngl | test | t/s |
| --- | --- | --- | --- | --- | --- | --- |
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | CUDA | 99 | tg32 | 184.52 ± 1.00 |
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | CUDA | 99 | tg64 | 180.36 ± 0.50 |
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | CUDA | 99 | tg128 | 178.16 ± 0.53 |

Edit: Looks like the bug affected qwen3 as well after adding the clamp (#16702).

With this change

| model | size | params | backend | ngl | test | t/s |
| --- | --- | --- | --- | --- | --- | --- |
| qwen3moe 30B.A3B Q4_0 | 16.11 GiB | 30.53 B | CUDA | 99 | tg32 | 162.29 ± 0.41 |
| qwen3moe 30B.A3B Q4_0 | 16.11 GiB | 30.53 B | CUDA | 99 | tg64 | 158.72 ± 0.64 |
| qwen3moe 30B.A3B Q4_0 | 16.11 GiB | 30.53 B | CUDA | 99 | tg128 | 156.83 ± 0.22 |

Current Master

| model | size | params | backend | ngl | test | t/s |
| --- | --- | --- | --- | --- | --- | --- |
| qwen3moe 30B.A3B Q4_0 | 16.11 GiB | 30.53 B | CUDA | 99 | tg32 | 157.21 ± 0.36 |
| qwen3moe 30B.A3B Q4_0 | 16.11 GiB | 30.53 B | CUDA | 99 | tg64 | 153.71 ± 0.41 |
| qwen3moe 30B.A3B Q4_0 | 16.11 GiB | 30.53 B | CUDA | 99 | tg128 | 151.85 ± 0.15 |

ggerganov (Member) commented on Oct 28, 2025

Did a quick bench on the DGX Spark:

GGML_CUDA=ON ./scripts/compare-commits.sh master pr/16821 llama-bench -m ./models/gpt-oss-20b/ggml-model-mxfp4.gguf -m models/gpt-oss-120b/ggml-model-mxfp4-00001-of-00003.gguf -fa 1 -d 0,4096,8192 -p 2048 -n 32 -ub 2048 -mmp 0
| Model | Test | t/s master | t/s pr/16821 | Speedup |
| --- | --- | --- | --- | --- |
| gpt-oss 120B MXFP4 MoE | pp2048 | 1850.41 | 1839.89 | 0.99 |
| gpt-oss 120B MXFP4 MoE | pp2048@d4096 | 1724.87 | 1788.50 | 1.04 |
| gpt-oss 120B MXFP4 MoE | pp2048@d8192 | 1647.51 | 1705.67 | 1.04 |
| gpt-oss 120B MXFP4 MoE | tg32 | 55.01 | 55.80 | 1.01 |
| gpt-oss 120B MXFP4 MoE | tg32@d4096 | 51.32 | 52.10 | 1.02 |
| gpt-oss 120B MXFP4 MoE | tg32@d8192 | 48.21 | 48.31 | 1.00 |
| gpt-oss 20B MXFP4 MoE | pp2048 | 3601.26 | 3553.66 | 0.99 |
| gpt-oss 20B MXFP4 MoE | pp2048@d4096 | 3350.17 | 3290.34 | 0.98 |
| gpt-oss 20B MXFP4 MoE | pp2048@d8192 | 3119.52 | 3087.23 | 0.99 |
| gpt-oss 20B MXFP4 MoE | tg32 | 77.43 | 77.48 | 1.00 |
| gpt-oss 20B MXFP4 MoE | tg32@d4096 | 72.54 | 72.30 | 1.00 |
| gpt-oss 20B MXFP4 MoE | tg32@d8192 | 68.69 | 69.10 | 1.01 |

Taking into account the uncertainty of the numbers, it's hard to say there is a significant difference. I think your earlier observations that fusion is less effective at lower memory bandwidths might be correct.

am17an (Collaborator, Author) commented on Oct 28, 2025

@ggerganov this PR on its own would only add 4-5%; what I meant is that this change together with #16715 would be around 10% in TG.

Edit: looks like it doesn't help at all in TG, so yes, it could be the memory-bandwidth thing.

ggerganov (Member) commented:

Yes, here is a comparison of one commit right before #16715 against this PR:

ggml_cuda_init: GGML_CUDA_FORCE_MMQ: no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 CUDA devices:
Device 0: NVIDIA GB10, compute capability 12.1, VMM: yes

| model | size | params | backend | ngl | n_ubatch | fa | mmap | test | t/s |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 | 3591.02 ± 17.98 |
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | CUDA | 99 | 2048 | 1 | 0 | tg32 | 77.82 ± 0.51 |
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 @ d4096 | 3349.24 ± 11.75 |
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | CUDA | 99 | 2048 | 1 | 0 | tg32 @ d4096 | 72.66 ± 0.90 |

build: 3cfa9c3 (6840)

| model | size | params | backend | ngl | n_ubatch | fa | mmap | test | t/s |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 | 3578.68 ± 6.71 |
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | CUDA | 99 | 2048 | 1 | 0 | tg32 | 77.75 ± 0.26 |
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 @ d4096 | 3339.55 ± 12.86 |
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | CUDA | 99 | 2048 | 1 | 0 | tg32 @ d4096 | 72.88 ± 0.11 |

build: 91cd070 (6867)

Pretty much no difference on the DGX Spark.

JohannesGaessler (Collaborator) commented:

In principle, if you have more memory bandwidth, then the percentage of the runtime spent loading weights is comparatively low, while the percentage spent on overhead and other I/O is high. So in that regard it makes sense that fusion would be more beneficial for GPUs with high memory bandwidth. If you look at the numbers I posted in #16715 (comment), there were some AMD GPUs with comparatively low memory bandwidth that benefited from the fusion, but that could also be because the overhead (in my experience) is higher on AMD.
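
A quick back-of-envelope illustration of that argument (the weight size, overhead, and bandwidths below are assumptions picked for the sketch, not measurements from this PR): the fixed per-token overhead, which is roughly what fusion removes, is a small share of the step time at low bandwidth and a much larger share at high bandwidth.

```cpp
// Toy model, not a benchmark: estimate the overhead share of a decode step
// for two assumed memory bandwidths.
#include <cstdio>

int main() {
    const double weight_bytes = 2e9;  // assumed bytes of active weights read per token
    const double overhead_ms  = 0.5;  // assumed fixed per-token launch/sync overhead
    const double bandwidths_gbs[] = { 273.0, 3350.0 }; // assumed GB/s, LPDDR-class vs. HBM-class

    for (double bw : bandwidths_gbs) {
        const double load_ms  = weight_bytes / (bw * 1e9) * 1e3;
        const double total_ms = load_ms + overhead_ms;
        std::printf("%6.0f GB/s: weight load %5.2f ms, overhead share %4.1f%%\n",
                    bw, load_ms, 100.0 * overhead_ms / total_ms);
    }
    return 0;
}
```

With these made-up numbers the overhead share goes from roughly 6% to roughly 45%, which is why the same fusion can be invisible on one device and clearly measurable on another.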

github-actions bot added the Nvidia GPU and ggml labels on Oct 28, 2025
am17an merged commit 9a3ea68 into ggml-org:master on Oct 29, 2025 (73 of 83 checks passed)
am17an deleted the fix_topk_cuda branch on October 29, 2025 07:55
wqerrewetw added a commit to wqerrewetw/llama.cpp that referenced this pull request Oct 29, 2025
* Ci (#11) (#12)

* Fix cl (#7)

* Rename build-amd.yml to build-amd.yml.disabled

* Rename winget.yml to winget.yml.disabled

* Rename server.yml to server.yml.disabled

* Rename build.yml to build.yml.disabled

* Update release.yml

* Rename build-cmake-pkg.yml to build-cmake-pkg.yml.disabled

* Rename build-linux-cross.yml to build-linux-cross.yml.disabled

* Rename build-riscv-native.yml.disabled to build-riscv-native.yml

* Rename docker.yml.disabled to docker.yml

* Rename update-ops-docs.yml to update-ops-docs.yml.disabled

* Remove macOS-arm64 job from release workflow

Removed macOS-arm64 job and its associated steps from the release workflow.

* CUDA: Fix bug in topk-moe for gpt-oss (ggml-org#16821)

* CUDA: Fix bug in topk-moe for gpt-oss

When using ggml_can_fuse_subgraph, the output nodes which are passed are wrong. This causes `test-backend-ops` to still fuse nodes (because the nodes are not used elsewhere in the graph),
but it actually doesn't fuse in the actual gpt-oss

* fix for qwen3 too

* change ifndef to ifdef

* vulkan: Call ggml_vk_buffer_write_2d from ggml_vk_buffer_copy (ggml-org#16793)

This lets the copy to the destination device use the host-visible
vidmem optimization.

---------

Co-authored-by: Aman Gupta <[email protected]>
Co-authored-by: Jeff Bolz <[email protected]>
